{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b28dd454-0a3a-4e0f-8e74-6c4ad712b783",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "94e0fa4d-efad-41e1-9314-09a2ef12a438",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torchxrayvision as xrv\n",
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "import dataset_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ed4012-c6bb-4988-802f-0c11f2cd0057",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "794438a4-856b-4eaa-b46e-c1c654a242e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate the xinario view classifier used for the predictions below.\n",
    "# `.eval()` makes inference mode explicit instead of relying on whatever mode the\n",
    "# constructor leaves the network in (assumes ViewModel is a torch.nn.Module, whose\n",
    "# eval() returns the module itself -- TODO confirm).\n",
    "model = xrv.baseline_models.xinario.ViewModel().eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d608d78-ad75-466f-84f4-f9d5a4237e37",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7abcf913-ee6b-4d6e-b8fd-d4dcd9cd058f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Granuloma', 'Hemidiaphragm Elevation', 'Pleural_Thickening', 'Nodule', 'Mass', 'Cardiomegaly', 'Consolidation', 'Fibrosis', 'Scoliosis', 'Fracture', 'Atelectasis', 'Emphysema', 'Effusion', 'Air Trapping', 'Aortic Atheromatosis', 'Support Devices', 'Tuberculosis', 'Pneumothorax', 'Costophrenic Angle Blunting', 'Hilar Enlargement', 'Flattened Diaphragm', 'Edema', 'Bronchiectasis', 'Infiltration', 'Tube', 'Aortic Elongation', 'Pneumonia', 'Hernia']\n"
     ]
    }
   ],
   "source": [
    "# Per-image preprocessing, applied in order: XRayCenterCrop, then XRayResizer(224).\n",
    "preprocessing_steps = [\n",
    "    xrv.datasets.XRayCenterCrop(),\n",
    "    xrv.datasets.XRayResizer(224),\n",
    "]\n",
    "transform = torchvision.transforms.Compose(preprocessing_steps)\n",
    "\n",
    "# Load the 'pc' dataset through the project helper, passing views='*' so the rows\n",
    "# can be split by view in the cells that follow.\n",
    "d = dataset_utils.get_data('pc', views='*', transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7164bd40-b0c5-4dd6-b21a-50939ee5236e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['PA', 'L', 'AP', 'AP Supine', 'COSTAL', 'UNK', 'EXCLUDE'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct labels present in the metadata's `view` column.\n",
    "d.csv['view'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2984fdc2-3ab8-4edf-9cb2-7aca381ffe2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Positional indices of the rows whose view label is exactly 'PA'.\n",
    "# Other labels in the column (e.g. 'AP') are not matched by this comparison.\n",
    "frontal = np.flatnonzero(d.csv['view'] == 'PA')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "47717a84-2120-459e-ad08-df074bdadce3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Positional indices of the rows whose view label is exactly 'L'.\n",
    "lateral = np.flatnonzero(d.csv['view'] == 'L')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "342ba82d-5504-4e65-80cd-40e467bcf3a6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[23.1546, 16.9751]]) Frontal\n",
      "tensor([[23.6190, 15.1804]]) Frontal\n",
      "tensor([[23.9368, 15.9114]]) Frontal\n",
      "tensor([[20.4266, 14.5170]]) Frontal\n",
      "tensor([[25.9273, 14.4245]]) Frontal\n",
      "tensor([[24.4080, 13.7654]]) Frontal\n",
      "tensor([[25.0222, 15.7349]]) Frontal\n",
      "tensor([[23.8637, 16.7607]]) Frontal\n",
      "tensor([[22.3303, 13.5714]]) Frontal\n",
      "tensor([[21.2553, 14.0465]]) Frontal\n"
     ]
    }
   ],
   "source": [
    "# Run the view model on the first 10 'PA' samples and print, for each one, the raw\n",
    "# model output followed by the target name at the arg-max position.\n",
    "with torch.no_grad():  # inference only, so no gradient tracking is needed\n",
    "    for idx in frontal[:10]:\n",
    "        sample = d[idx]\n",
    "        scores = model(torch.from_numpy(sample['img']))\n",
    "        predicted_view = model.targets[scores.argmax()]\n",
    "        print(scores, predicted_view)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f9161e17-f505-4cf1-87ad-40a73e79ae21",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[17.3186, 26.7156]]) Lateral\n",
      "tensor([[15.9319, 24.5127]]) Lateral\n",
      "tensor([[20.1788, 34.1056]]) Lateral\n",
      "tensor([[20.5084, 35.7469]]) Lateral\n",
      "tensor([[20.0122, 36.1225]]) Lateral\n",
      "tensor([[20.1512, 29.6003]]) Lateral\n",
      "tensor([[21.8098, 32.7101]]) Lateral\n",
      "tensor([[18.7384, 35.3062]]) Lateral\n",
      "tensor([[19.8528, 28.8093]]) Lateral\n",
      "tensor([[20.8488, 33.3455]]) Lateral\n"
     ]
    }
   ],
   "source": [
    "# Run the view model on the first 10 'L' samples and print, for each one, the raw\n",
    "# model output followed by the target name at the arg-max position.\n",
    "with torch.no_grad():  # inference only, so no gradient tracking is needed\n",
    "    for idx in lateral[:10]:\n",
    "        sample = d[idx]\n",
    "        scores = model(torch.from_numpy(sample['img']))\n",
    "        predicted_view = model.targets[scores.argmax()]\n",
    "        print(scores, predicted_view)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b9abb6-bf8c-451b-8eea-237da57dc6dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdfb6d6b-3093-468e-b32a-cc3bb857994f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b35d5965-10d4-4d1e-bfca-4df7afbd5f7d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b25a4c6-e41f-4f0b-8103-672902ab2d28",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
